function Results = kass_method(raster,times,params)
%%%%%%%%%%%%%%%%%%%% input

%%%%%%%%%%%% raster(ms) [nx1 cell]:
% n session cells of trials containing spike time array,
% aligned by time of stimulus on zero.
% Each session cell holds one spike-time array per trial; row or column
% cell arrays of trials are both accepted.

%%%%%%%%%%%% times(ms) [1xt]
% time vector of evaluation (bin-centre) times,
% in range [bin_width/2, end-bin_width/2].
% Consecutive entries must be exactly bin_width apart (contiguous,
% non-overlapping bins); an error is raised otherwise.
% This function does not support sliding window, please use var_decom.
% example: 100:100:1000 (with bin_width = 100)

%%%%%%%%%%%% params
%%% bin_width(ms) [double] (required)
%%% K [int]: number of correlated sub-bins on each side of a sub-bin
% (only sub-bin pairs whose indices differ by at most K are summed).
% If K is absent, it is computed per neuron and time point from
% params.cv2_all_neu via prep_effcorr_ms_data.
%%% cv2_all_neu [nxt] (required only when K is absent):
% per-neuron values passed to prep_effcorr_ms_data
% (presumably CV2 of each neuron - confirm against the caller)
%%% bin_count_kass [int] (optional, default: 100):
% number of equal sub-bins each bin_width window is split into
%%% smooth_bin [int] (optional, default: round(bin_count_kass/20)):
% movmean window (in sub-bins) applied to the sub-bin spike
% probabilities; smooth_bin <= 0 uses uniform probabilities instead
%%% cv_count [int] (optional, default: 100):
% number of bootstrap resamples of trials
%%% keep_indx(session number) [mxt] (optional, default: all sessions):
% specifies selected sessions/neurons for each time point

%%%%%%%%%%%%%%%%%%%% OUTPUT  (struct Results, every field is [1xt])
% With a single session (n == 1) the *_SE fields come from the bootstrap
% resamples; with several sessions they are the standard error across the
% sessions selected by keep_indx (NaN values are ignored).
%%%% fr:        firing rate (spike count / (bin_width*1e-3)) [1xt]
%%%% fr_SE:     SE of firing rate [1xt]
%%%% FF:        fano factor [1xt]
%%%% FF_SE:     SE of fano factor [1xt]
%%%% Var:       Variance [1xt]
%%%% Var_SE:    SE of Variance [1xt]

%%%% VEC:       Kass VEC [1xt]
%%%% VEC_SE:    SE of Kass VEC [1xt]
%%%% nRV:       Kass FFS [1xt]
%%%% nRV_SE:    SE of Kass FFS [1xt]
%%%% EVC:       Kass EVC [1xt]
%%%% EVC_SE:    SE of Kass EVC [1xt]
%%%% nSI:       Kass FF0 [1xt]
%%%% nSI_SE:    SE of Kass FF0 [1xt]

% Citations: Saleh F, Fakharian M & Ghazizadeh A "Stimulus presentation can enhance spiking irregularity across subcortical and cortical regions." 2021
% G. Vinci, V. Ventura, M. A. Smith, R. E. Kass, Separating spike count correlation from firing rate correlation. Neural Comput. 28, 2709-2733 (2016)

n = length(raster);
T = length(times);

% Check Inputs
try
    bin_width = params.bin_width;
catch
    error('bin_width is required!');
end
% The per-bin counts are produced by a single histcounts call with edges
% times(1)-bin_width/2 : bin_width : times(end)+bin_width/2, so column t
% only corresponds to times(t) when the bins are contiguous.
if T > 1 && any(abs(diff(times) - bin_width) > 1e-9 * max(1, abs(bin_width)))
    error('times must be spaced exactly bin_width apart (no sliding window; use var_decom).');
end
try
    K = params.K;
    flag_K_adaptive = 0;
    cv2_all_neu = ones(n,T);
catch
    if ~isfield(params,'cv2_all_neu')
        error('params.K or params.cv2_all_neu is required!');
    end
    flag_K_adaptive = 1;
    K = nan;
    cv2_all_neu = params.cv2_all_neu;
end
try
    bin_count_kass = params.bin_count_kass;
catch
    bin_count_kass = 100;
end
try
    smooth_bin = params.smooth_bin;
catch
    smooth_bin = round(bin_count_kass/20);
end
if(~isfield(params,'keep_indx'))
    params.keep_indx = repmat(1:n,T,1)';
end
if(~isfield(params,'cv_count'))
    params.cv_count = 100;
end

keep_indx = params.keep_indx;
cv_count = params.cv_count;

% Per-neuron results [n x t]
nRV_tot = zeros(n,T);
nSI_tot = zeros(n,T);
VEC_tot = zeros(n,T);
EVC_tot = zeros(n,T);
fr_tot  = zeros(n,T);
var_tot = zeros(n,T);
FF_tot  = zeros(n,T);

% Bootstrap SEs, only filled when n == 1
nSI_SE = zeros(1,T);
nRV_SE = zeros(1,T);
EVC_SE = zeros(1,T);
VEC_SE = zeros(1,T);
fr_SE  = zeros(1,T);
var_SE = zeros(1,T);
FF_SE  = zeros(1,T);

parfor i = 1:n
    % Force a column cell of trials so cell2mat stacks trials as rows
    % (a row cell would be concatenated side by side instead).
    raster_local = raster{i};
    raster_local = raster_local(:);
    raster_local = cellfun(@(x) reshape(x,length(x),1),raster_local,...
        'UniformOutput',false);

    % Bootstrap replicates, one row per resample
    nRV_bin_cv = zeros(cv_count,T);
    nSI_bin_cv = zeros(cv_count,T);
    VEC_bin_cv = zeros(cv_count,T);
    EVC_bin_cv = zeros(cv_count,T);

    % Spike counts per trial in each bin_width window: [trials x t]
    spike_cell = cellfun(@(x) histcounts...
        (x,times(1)-bin_width/2:bin_width:times(end)+bin_width/2),...
        raster_local,'UniformOutput',false);
    spike_mat_bin = cell2mat(spike_cell);

    % Reductions are taken explicitly across trials (dim 1).
    mean_botc = bootstrp(cv_count,@(x) nanmean(x,1),spike_mat_bin);
    var_botc  = bootstrp(cv_count,@(x) nanvar(x,0,1),spike_mat_bin);
    fano_botc = bootstrp(cv_count,@(x) ...
        nanvar(x,0,1)./nanmean(x,1),spike_mat_bin);

    neu_fir = nanmean(mean_botc,1)/(bin_width*1e-3);
    fr_tot(i,:)  = neu_fir;
    var_tot(i,:) = nanmean(var_botc,1);
    FF_tot(i,:)  = nanmean(fano_botc,1);

    if n == 1
        % nanstd(x,0,1): N-1 normalisation, taken across resamples (dim 1).
        % NOTE(review): the bootstrap SE is usually the std of the replicates
        % itself (as for nSI_SE below); dividing by sqrt(cv_count) makes these
        % SEs shrink as cv_count grows - confirm this is intended.
        fr_SE(i,:)  = nanstd(mean_botc,0,1)/...
            (bin_width*1e-3)/sqrt(cv_count);
        var_SE(i,:) = nanstd(var_botc,0,1)/sqrt(cv_count);
        FF_SE(i,:)  = nanstd(fano_botc,0,1)/sqrt(cv_count);
    end

    neu_cv = cv2_all_neu(i,:);

    for t = 1:T
        % Split the t-th window into bin_count_kass equal sub-bins.
        bin_edge = linspace(times(t)-...
            bin_width/2,times(t)+bin_width/2,bin_count_kass+1);
        spike_cell = cellfun(@(x) histcounts(x,bin_edge),...
            raster_local,'UniformOutput',false);
        spike_mat = cell2mat(spike_cell);   % [trials x bin_count_kass]

        % Probability of a spike falling in each sub-bin
        if(smooth_bin <= 0)
            prob_bin_time_smoothed = ones(bin_count_kass,1) ...
                * 1/bin_count_kass;
        else
            prob_bin_time = nansum(spike_mat,1) / nansum(sum(spike_mat));
            prob_bin_time_smoothed = movmean(prob_bin_time,smooth_bin);
            prob_bin_time_smoothed = ...
                prob_bin_time_smoothed/sum(prob_bin_time_smoothed);
        end

        if flag_K_adaptive
            % prep_effcorr_ms_data is defined elsewhere; its result is divided
            % by bin_width and scaled to sub-bins here (presumably a
            % correlation time in ms - confirm).
            neu_K = ceil((prep_effcorr_ms_data(neu_cv(t),...
                neu_fir(t),bin_width)/bin_width)*bin_count_kass);
        else
            neu_K = K;
        end

        nSI_bin_cv(:,t) = ...
            bootstrp(cv_count,@(x) ...
            Phi_kass_opt(x,neu_K,prob_bin_time_smoothed),spike_mat);
        phi = nanmean(nSI_bin_cv(:,t));
        VEC_bin_cv(:,t) = bootstrp(cv_count,@(x) nanvar(x,0,1) - ...
            phi *  nanmean(x),spike_mat_bin(:,t));
        nRV_bin_cv(:,t) = bootstrp(cv_count,@(x) ...
            (nanvar(x,0,1) - phi*nanmean(x))/(nanmean(x)*bin_width*1e-3),...
            spike_mat_bin(:,t));
        EVC_bin_cv(:,t) = bootstrp(cv_count,@(x) phi*...
            nanmean(x),spike_mat_bin(:,t));
    end

    % Divisions by zero counts produce +/-Inf; treat them as missing.
    nSI_bin_cv(isinf(nSI_bin_cv)) = NaN;
    nRV_bin_cv(isinf(nRV_bin_cv)) = NaN;
    EVC_bin_cv(isinf(EVC_bin_cv)) = NaN;
    VEC_bin_cv(isinf(VEC_bin_cv)) = NaN;

    if n == 1
        nSI_SE(i,:) = nanstd(nSI_bin_cv,[],1);
        nRV_SE(i,:) = nanstd(nRV_bin_cv,[],1);
        EVC_SE(i,:) = nanstd(EVC_bin_cv,[],1);
        VEC_SE(i,:) = nanstd(VEC_bin_cv,[],1);
    end

    nSI_tot(i,:) = nanmean(nSI_bin_cv,1);
    VEC_tot(i,:) = nanmean(VEC_bin_cv,1);
    nRV_tot(i,:) = nanmean(nRV_bin_cv,1);
    EVC_tot(i,:) = nanmean(EVC_bin_cv,1);
end

if n == 1
    Results.nSI = nSI_tot;
    Results.nRV = nRV_tot;
    Results.EVC = EVC_tot;
    Results.VEC = VEC_tot;

    Results.nSI_SE = nSI_SE;
    Results.nRV_SE = nRV_SE;
    Results.EVC_SE = EVC_SE;
    Results.VEC_SE = VEC_SE;

    Results.fr  = fr_tot;
    Results.FF  = FF_tot;
    Results.Var = var_tot;

    Results.fr_SE  = fr_SE;
    Results.FF_SE  = FF_SE;
    Results.Var_SE = var_SE;
else
    % Pool the kept sessions per time point (field order kept stable).
    Results = struct();
    [Results.nSI, se_nSI] = kass_population_summary(nSI_tot, keep_indx);
    [Results.nRV, se_nRV] = kass_population_summary(nRV_tot, keep_indx);
    [Results.EVC, se_EVC] = kass_population_summary(EVC_tot, keep_indx);
    [Results.VEC, se_VEC] = kass_population_summary(VEC_tot, keep_indx);

    Results.nSI_SE = se_nSI;
    Results.nRV_SE = se_nRV;
    Results.EVC_SE = se_EVC;
    Results.VEC_SE = se_VEC;

    [Results.fr,  se_fr]  = kass_population_summary(fr_tot,  keep_indx);
    [Results.FF,  se_FF]  = kass_population_summary(FF_tot,  keep_indx);
    [Results.Var, se_Var] = kass_population_summary(var_tot, keep_indx);

    Results.fr_SE  = se_fr;
    Results.FF_SE  = se_FF;
    Results.Var_SE = se_Var;
end

end

function [mu, se] = kass_population_summary(X, keep_indx)
% Mean and standard error across selected sessions, per time point.
%   X          [n x t] per-session values.
%   keep_indx  [m x t] row indices into X; column t lists the sessions
%              pooled at time point t.
%   mu, se     [1 x t]. NaN entries are ignored, and the SE divides by the
%              square root of the number of non-NaN values actually pooled.
T = size(X,2);
mu = nan(1,T);
se = nan(1,T);
for t = 1:T
    vals = X(keep_indx(:,t),t);
    mu(t) = nanmean(vals,1);
    se(t) = nanstd(vals,0,1)/sqrt(sum(~isnan(vals)));
end
end

function phi = Phi_kass_opt(Ybins,K,p)
% Kass phi (nSI) estimate for one window.
%   Ybins  [trials x bins] spike counts of the sub-bins of one window.
%   K      band half-width: only products of bins whose indices differ by
%          at most K (diagonals -K..K) are summed.
%   p      sub-bin probabilities; any vector with size(Ybins,2) elements.
%   phi    (band-summed residual products / (trials * (1 - band-summed
%          p*p'))) divided by the mean spike count per trial.
num_trials = size(Ybins,1);
num_bins = size(Ybins,2);

lags = -K:1:K;
% Sum of the entries of M lying on diagonals -K..K.
band_sum = @(M) sum(arrayfun(@(d) sum(diag(M, d)), lags));

prob_row = reshape(p,1,num_bins);
denominator = num_trials * (1 - band_sum(prob_row' * prob_row));

% Trials without spikes have a zero residual, so only spiking trials are visited.
trial_totals = sum(Ybins,2);
numerator = 0;
for r = reshape(find(trial_totals),1,[])
    residual = Ybins(r,:) - prob_row * trial_totals(r);   % Yrj - pj * Yr
    numerator = numerator + band_sum(residual' * residual);
end

phi = (numerator/denominator)/(mean(trial_totals));
end
